{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mnist_reader\n",
    "\n",
    "# Directory holding the dataset files; change this one constant to point elsewhere.\n",
    "DATA_DIR = './'\n",
    "\n",
    "# Load the training and held-out splits through the local `mnist_reader` helper.\n",
    "# NOTE(review): presumably each X row is a flattened 28x28 image - confirm against mnist_reader.\n",
    "X_train, y_train = mnist_reader.load_mnist(DATA_DIR, kind='train')\n",
    "X_test, y_test = mnist_reader.load_mnist(DATA_DIR, kind='t10k')\n",
    "\n",
    "# Fail fast if features and labels are misaligned.\n",
    "assert len(X_train) == len(y_train), 'train features/labels length mismatch'\n",
    "assert len(X_test) == len(y_test), 'test features/labels length mismatch'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from sklearn.tree import DecisionTreeClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# X_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # X_train[36000].reshape(28,28)\n",
    "# %matplotlib inline\n",
    "# import matplotlib\n",
    "# import matplotlib.pyplot as plt\n",
    "# plt.imshow(X_train[36000].reshape(28,28),cmap=matplotlib.cm.binary)\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import numpy as np\n",
    "# shuffle_index = np.random.permutation(60000)\n",
    "# X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# y_train_8 = (y_train == 8)\n",
    "# y_train_8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# y_train_8.reshape(10,-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tree = DecisionTreeClassifier()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tree.fit(X_train,y_train)\n",
    "# y_pred = tree.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sum = 0.0\n",
    "# for i in range(10000):\n",
    "#     if(y_pred[i] == y_test[i]):\n",
    "#         sum = sum+1\n",
    "    \n",
    "# print('Test set score: %f' % (sum/10000.))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "# Base k-NN estimator handed to GridSearchCV later in the notebook.\n",
    "# The library defaults are spelled out so the starting point is visible.\n",
    "knn_clf = KNeighborsClassifier(n_neighbors=5, weights='uniform')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# knn_clf.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from sklearn.model_selection import cross_val_score\n",
    "# y_pred = knn_clf.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cross_val_score(knn_clf,X_train,y_train,cv=3,scoring=\"accuracy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sum = 0.0\n",
    "# for i in range(10000):\n",
    "#     if(y_pred[i] == y_test[i]):\n",
    "#         sum = sum+1\n",
    "    \n",
    "# print('Test set score: %f' % (sum/10000.))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import GridSearchCV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/xuzy/opt/anaconda3/lib/python3.7/site-packages/sklearn/feature_extraction/text.py:17: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3,and in 3.9 it will stop working\n",
      "  from collections import Mapping, defaultdict\n"
     ]
    }
   ],
   "source": [
    "from sklearn import datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# param_grid = [\n",
    "#     {\n",
    "#           'weights':['uniform'],\n",
    "#           'n_neighbors':[i for i in range(1,10)]\n",
    "#     },\n",
    "#     {\n",
    "#           'weights':['distance'],\n",
    "#           'n_neighbors':[i for i in range(1,10)],\n",
    "#           'p':[i for i in range(1,6)]\n",
    "#     }\n",
    "# ]\n",
    "# Compact grid: 2 weighting schemes x 3 neighbourhood sizes = 6 candidates.\n",
    "param_grid = [{'weights': ['uniform', 'distance'], 'n_neighbors': [3, 4, 5]}]\n",
    "\n",
    "# Pin the fold count explicitly: GridSearchCV's default changed from 3-fold to\n",
    "# 5-fold in scikit-learn 0.22, so relying on it makes results version-dependent.\n",
    "# cv=3 matches the 'Fitting 3 folds' run recorded in this notebook.\n",
    "# n_jobs=-1 uses all available cores; verbose=2 logs each fit.\n",
    "grid_search = GridSearchCV(knn_clf, param_grid, cv=3, n_jobs=-1, verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 3 folds for each of 6 candidates, totalling 18 fits\n",
      "[CV] n_neighbors=3, weights=uniform ..................................\n",
      "[CV] n_neighbors=3, weights=uniform ..................................\n",
      "[CV] n_neighbors=3, weights=uniform ..................................\n",
      "[CV] n_neighbors=3, weights=distance .................................\n",
      "[CV] n_neighbors=3, weights=distance .................................\n",
      "[CV] n_neighbors=3, weights=distance .................................\n",
      "[CV] n_neighbors=4, weights=uniform ..................................\n",
      "[CV] n_neighbors=4, weights=uniform ..................................\n",
      "[CV] n_neighbors=4, weights=uniform ..................................\n",
      "[CV] n_neighbors=4, weights=distance .................................\n",
      "[CV] n_neighbors=4, weights=distance .................................\n",
      "[CV] n_neighbors=4, weights=distance .................................\n"
     ]
    }
   ],
   "source": [
    "# %%time\n",
    "# Run the exhaustive search over `param_grid` on the full training split.\n",
    "grid_search.fit(X=X_train, y=y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_search.best_estimator_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# knn_clf.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
